[ROCm] Fix allreduce + RMSNorm fusion pattern matching #41767
Open
rbrugaro-amd wants to merge 3 commits into vllm-project:main from
Conversation
Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
Contributor
Code Review
This pull request updates the allreduce_rms_fusion pass to initialize the residual tensor with zeros instead of uninitialized memory. Additionally, it modifies the RMSNorm forward pass in layernorm.py to conditionally pass the variance_size_override argument only when it is not None. I have no feedback to provide as there were no review comments.
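The residual-initialization issue can be illustrated with a minimal sketch (pure-Python stand-ins with hypothetical names, not the actual vLLM code):

```python
# Sketch of the empty_like -> zeros_like fix described above, with plain
# Python lists standing in for tensors. The fused allreduce+rmsnorm kernel
# always adds res_inp to its output, so a freshly created residual buffer
# must contain zeros rather than uninitialized memory.
import random

def empty_like(x):
    # Models torch.empty_like: a same-shaped allocation holding arbitrary
    # leftover values.
    return [random.uniform(-1e6, 1e6) for _ in x]

def zeros_like(x):
    # Models torch.zeros_like: a same-shaped buffer of zeros.
    return [0.0 for _ in x]

def fused_kernel_add(out, res_inp):
    # Stand-in for the fused kernel's unconditional residual add.
    return [o + r for o, r in zip(out, res_inp)]

out = [1.0, 2.0, 3.0]
# With zeros, the unconditional add is a no-op and the output is correct;
# with "empty" (uninitialized) memory, garbage values would corrupt it.
assert fused_kernel_add(out, zeros_like(out)) == out
```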
I tested this. Without this patch: … With this patch: …
self.weight.data if self.pass_weight else None,
self.variance_epsilon,
self.variance_size_override,
*(
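The hunk above ends at the conditional-unpacking expression. A self-contained sketch of the idea (hypothetical function names, not the actual vLLM IR ops):

```python
# Sketch of the conditional-unpacking fix: when variance_size_override is
# None, splatting an empty tuple removes the argument from the call
# entirely, so the traced FX graph records a 3-argument rms_norm call that
# the fusion pattern expects, instead of a 4-argument call ending in None.

def rms_norm(*args):
    # Stand-in for ir.ops.rms_norm; returns the arity so we can observe
    # how many arguments actually reached the op.
    return len(args)

def forward_native(x, weight, eps, variance_size_override=None):
    return rms_norm(
        x,
        weight,
        eps,
        *(() if variance_size_override is None
          else (variance_size_override,)),
    )

# 3-argument call when the override is None; 4 when it is set.
assert forward_native("x", "w", 1e-6) == 3
assert forward_native("x", "w", 1e-6, 128) == 4
```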
Collaborator
Let's just fix the patterns in the pass instead
Summary
Fixes two issues that broke the `allreduce + RMSNorm` fusion pass introduced in #37646, caused by subsequent refactoring in #36823:

1. `torch.empty_like` → `torch.zeros_like` in `AiterAllreduceFusedRMSNormPattern._replacement` (`allreduce_rms_fusion.py`): the fused allreduce+rmsnorm kernel always adds `res_inp`; using `empty_like` leaves undefined values that corrupt outputs. Changed to `zeros_like` so the add is a no-op when the residual is freshly created.

2. Conditional `variance_size_override` argument in `RMSNorm.forward_native` (`layernorm.py`): after the IR refactoring, `ir.ops.rms_norm` and `ir.ops.fused_add_rms_norm.maybe_inplace` were unconditionally passed `self.variance_size_override` (even when `None`). This produced 4-argument calls in the FX graph, but the fusion patterns expect 3 arguments, which prevented pattern matching entirely. Fixed by conditionally unpacking `variance_size_override` only when it is not `None`.

Testing
Tested with `Kimi-K2-Thinking-MXFP4` on 4x MI355X (TP=4):

- vLLM `0.20.1rc1.dev153+gcfd2573f2` (base commit `cfd2573f2`)
- amd-aiter `0.1.12.post2.dev126+g033d8b9db`
- torch `2.10.0+git8514f05`

Fusion pass results (confirmed via `VLLM_DEBUG_DUMP_PATH` graph dumps and custom logging):

- `all_reduce_fusion_pass`: 244 pattern matches across 2 compile ranges (122 per range)
- `mla_dual_rms_norm_fusion_pass`: 183 matches
- `fused_add_rms_norm=aiter` implementation selected
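To reproduce the graph dumps used above, a launch along these lines should work (model path and dump directory are illustrative, not the exact command from this PR):

```shell
# Point VLLM_DEBUG_DUMP_PATH at a writable directory so the compiler
# passes dump their graphs, then serve the model with tensor parallelism.
export VLLM_DEBUG_DUMP_PATH=/tmp/vllm_debug_dump
vllm serve Kimi-K2-Thinking-MXFP4 --tensor-parallel-size 4
```

The dumped FX graphs can then be grepped for the fused op to confirm the pattern-match counts reported above.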